package net.csdn.business.common.config.wrapper;

import org.slf4j.MDC;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.concurrent.ListenableFuture;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;

/**
 * Utilities for propagating the SLF4J {@link MDC} from the thread that submits a task to the
 * thread that executes it.
 *
 * <p>The {@code wrap} methods take an MDC snapshot and install it around the wrapped task.
 * {@link CustomThreadPoolTaskExecutor} applies this automatically to every task it is given.
 *
 * @author likun
 * @date 2022/6/18 16:13
 */

public class ThreadPoolWrapper {

    /**
     * Wraps a {@link Callable} so that it runs with the given MDC context.
     *
     * <p>When the returned callable runs, it first saves the executing thread's current MDC. It
     * then installs {@code context}, or clears the MDC when {@code context} is {@code null}.
     * After the task finishes, normally or exceptionally, the saved MDC is restored. As a
     * result, the submitter's context does not linger on a reused worker thread, and a task
     * that happens to run on the submitting thread does not overwrite that thread's own context.
     *
     * @param callable the task to wrap; must not be {@code null}
     * @param context  the MDC snapshot to install while the task runs (the executor below passes
     *                 {@link MDC#getCopyOfContextMap()} of the submitting thread); may be
     *                 {@code null}, in which case the task runs with an empty MDC
     * @param <T>      the task's result type
     * @return a callable that delegates to {@code callable} with {@code context} installed
     * @throws NullPointerException if {@code callable} is {@code null}
     */
    public static <T> Callable<T> wrap(final Callable<T> callable, final Map<String, String> context) {
        // Fail fast on the submitting thread instead of failing later inside a worker.
        Objects.requireNonNull(callable, "callable");
        return () -> {
            // Remember the executing thread's own MDC so it can be put back afterwards.
            final Map<String, String> previous = MDC.getCopyOfContextMap();
            applyContext(context);
            try {
                return callable.call();
            } finally {
                applyContext(previous);
            }
        };
    }

    /**
     * Wraps a {@link Runnable} so that it runs with the given MDC context.
     *
     * <p>When the returned runnable runs, it first saves the executing thread's current MDC. It
     * then installs {@code context}, or clears the MDC when {@code context} is {@code null}.
     * After the task finishes, normally or exceptionally, the saved MDC is restored.
     *
     * @param runnable the task to wrap; must not be {@code null}
     * @param context  the MDC snapshot to install while the task runs; may be {@code null}, in
     *                 which case the task runs with an empty MDC
     * @return a runnable that delegates to {@code runnable} with {@code context} installed
     * @throws NullPointerException if {@code runnable} is {@code null}
     */
    public static Runnable wrap(final Runnable runnable, final Map<String, String> context) {
        // Fail fast on the submitting thread instead of failing later inside a worker.
        Objects.requireNonNull(runnable, "runnable");
        return () -> {
            // Remember the executing thread's own MDC so it can be put back afterwards.
            final Map<String, String> previous = MDC.getCopyOfContextMap();
            applyContext(context);
            try {
                runnable.run();
            } finally {
                applyContext(previous);
            }
        };
    }

    /**
     * Makes {@code context} the current thread's MDC, or clears the MDC when it is {@code null}.
     */
    private static void applyContext(final Map<String, String> context) {
        if (context == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(context);
        }
    }

    /**
     * Custom thread pool that extends Spring's original {@link ThreadPoolTaskExecutor}.
     *
     * <p>Every {@code execute}, {@code submit} and {@code submitListenable} override captures
     * {@link MDC#getCopyOfContextMap()} on the calling thread. It wraps the task with that
     * snapshot via {@code ThreadPoolWrapper.wrap(...)` and then delegates to the superclass
     * implementation.
     */
    // NOTE(review): the blanket suppression presumably hides deprecation warnings on some of the
    // overridden Spring methods; consider narrowing it to the specific warning once confirmed.
    @SuppressWarnings({"all"})
    public static class CustomThreadPoolTaskExecutor extends ThreadPoolTaskExecutor {

        /** Executes {@code task} with the caller's current MDC installed. */
        @Override
        public void execute(Runnable task) {
            super.execute(ThreadPoolWrapper.wrap(task, MDC.getCopyOfContextMap()));
        }

        /** Executes {@code task} with the caller's current MDC installed. */
        @Override
        public void execute(Runnable task, long startTimeout) {
            super.execute(ThreadPoolWrapper.wrap(task, MDC.getCopyOfContextMap()), startTimeout);
        }

        /** Submits {@code task} with the caller's current MDC installed. */
        @Override
        public <T> Future<T> submit(Callable<T> task) {
            return super.submit(ThreadPoolWrapper.wrap(task, MDC.getCopyOfContextMap()));
        }

        /** Submits {@code task} with the caller's current MDC installed. */
        @Override
        public Future<?> submit(Runnable task) {
            return super.submit(ThreadPoolWrapper.wrap(task, MDC.getCopyOfContextMap()));
        }

        /** Submits {@code task} with the caller's current MDC installed. */
        @Override
        public ListenableFuture<?> submitListenable(Runnable task) {
            return super.submitListenable(ThreadPoolWrapper.wrap(task, MDC.getCopyOfContextMap()));
        }

        /** Submits {@code task} with the caller's current MDC installed. */
        @Override
        public <T> ListenableFuture<T> submitListenable(Callable<T> task) {
            return super.submitListenable(ThreadPoolWrapper.wrap(task, MDC.getCopyOfContextMap()));
        }
    }
}
